[JAX] Remove GSPMD tests + adding guards and warning msg for GSPMD rules #2702

Merged: phu0ngng merged 2 commits into NVIDIA:main from phu0ngng:rm_gspmd on Mar 3, 2026

@phu0ngng (Collaborator) commented Feb 24, 2026

Description

GSPMD sharding propagation is being deprecated in favour of Shardy, which is now the default JAX partitioner.
This commit removes all GSPMD-related tests. The GSPMD sharding propagation rules will be kept for another 3 months until June 2026.

GSPMD rules with the existing primitives will work with older JAX versions (until 0.9.1) with a printed warning.
New primitives that do not have GSPMD rules will raise an explicit error if users attempt to use them with GSPMD, rather than crashing.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps bot (Contributor) commented Feb 24, 2026

Greptile Summary

This PR deprecates GSPMD sharding propagation in the JAX backend of TransformerEngine in preparation for its full removal in June 2026. It removes all GSPMD-specific test variants across distributed test suites and example encoders, strips the use_shardy toggle from test helpers and CLI parsers, and adds a version-guarded registration path in BasePrimitive so that:

  • On JAX ≤ 0.9.1, infer_sharding_from_operands is still registered (for backwards compatibility) with a one-shot DeprecationWarning.
  • On JAX > 0.9.1, the infer_sharding_from_operands kwarg is omitted entirely from def_partition(), since the API was removed from JAX.
  • New primitives that never implemented infer_sharding_from_operands inherit a base-class version that raises NotImplementedError with a clear message if GSPMD is attempted.
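The registration behaviour in the three bullets above can be sketched roughly as follows. This is a minimal, stdlib-only illustration and not the code in base.py: `parse_version`, `gspmd_supported`, `build_partition_kwargs`, and the two toy primitive classes are hypothetical names, the tuple comparison is a stand-in for the PkgVersion check mentioned in the review, and the returned dictionary only mimics the keyword arguments that would be handed to def_partition().

```python
import re


def parse_version(version: str) -> tuple:
    """Simplified parser: keep only the leading major.minor.patch numbers."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def gspmd_supported(jax_version: str) -> bool:
    # Boundary taken from the PR text: GSPMD rules are kept up to JAX 0.9.1.
    return parse_version(jax_version) <= (0, 9, 1)


class BasePrimitive:
    """Toy base class; the real BasePrimitive has many more members."""

    @classmethod
    def partition(cls, *args, **kwargs):
        return "partitioned"

    @classmethod
    def shardy_sharding_rule(cls, *args, **kwargs):
        return "rule"

    @classmethod
    def infer_sharding_from_operands(cls, *args, **kwargs):
        # Default for primitives that never implemented a GSPMD rule.
        # The message text here is illustrative.
        raise NotImplementedError(
            f"{cls.__name__} has no GSPMD rule. Please use Shardy partitioner instead."
        )


class LegacyPrimitive(BasePrimitive):
    """Stands in for an existing primitive that still ships a GSPMD rule."""

    @classmethod
    def infer_sharding_from_operands(cls, *args, **kwargs):
        return "gspmd-sharding"


class NewPrimitive(BasePrimitive):
    """Stands in for a new primitive that only has a Shardy rule."""


def build_partition_kwargs(cls, jax_version: str) -> dict:
    """Collect the keyword arguments for registration on a given JAX version."""
    kwargs = {
        "partition": cls.partition,
        "sharding_rule": cls.shardy_sharding_rule,
    }
    if gspmd_supported(jax_version):
        # Older JAX: still pass the GSPMD callback (own rule or raising default).
        kwargs["infer_sharding_from_operands"] = cls.infer_sharding_from_operands
    return kwargs
```

With this sketch, `build_partition_kwargs(LegacyPrimitive, "0.10.0")` contains no GSPMD entry at all, while on `"0.9.0"` the entry is present and, for `NewPrimitive`, raises NotImplementedError only when it is actually called.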

Key changes:

  • transformer_engine/jax/cpp_extensions/base.py: Adds _JAX_GSPMD_SUPPORTED version flag, converts infer_sharding_from_operands from an @abstractmethod to a default @classmethod that raises NotImplementedError, adds _warn_gspmd_deprecation_once() helper, and conditionally passes GSPMD kwargs in register_primitive().
  • All distributed test files (test_distributed_fused_attn.py, test_distributed_layernorm.py, test_distributed_layernorm_mlp.py, test_distributed_permutation.py, test_distributed_softmax.py): Remove use_shardy parametrize, jax.config.update("jax_use_shardy_partitioner", ...) calls, and all *_shardy / *_gspmd test methods.
  • Encoder examples: Remove --enable-shardy CLI argument, enable_shardy exec helper parameter, and all corresponding *_shardy test cases from both Python files and the shell test runner.

Notable concern: The deprecation warning in _warn_gspmd_deprecation_once() is triggered at primitive registration time (i.e., at import time), not when GSPMD sharding propagation is actually invoked. This means users on JAX ≤ 0.9.1 who have already switched to Shardy will still receive the "Use it at your own risk" DeprecationWarning on every process startup, even though they are not using GSPMD.
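For readers unfamiliar with the pattern, a one-shot warning helper of the kind named above might look like the sketch below. It is an assumption about the general shape, not the actual `_warn_gspmd_deprecation_once()` from base.py: the flag name, the message text, and the boolean return value are all made up for illustration.

```python
import warnings

# Hypothetical module-level flag that remembers whether the warning was shown.
_gspmd_warning_issued = False


def warn_gspmd_deprecation_once() -> bool:
    """Emit a DeprecationWarning at most once per process.

    Returns True if this call emitted the warning, False if it was suppressed.
    """
    global _gspmd_warning_issued
    if _gspmd_warning_issued:
        return False
    _gspmd_warning_issued = True
    warnings.warn(
        "GSPMD sharding propagation rules are deprecated; "
        "switch to the Shardy partitioner.",  # illustrative wording only
        DeprecationWarning,
        stacklevel=2,
    )
    return True
```

Because the flag lives at module level, where the helper is called determines when users see the message: calling it from registration code surfaces the warning at import, whereas calling it from inside the GSPMD callback would surface it only on actual use.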

Confidence Score: 4/5

  • Safe to merge; the GSPMD removal is clean and well-scoped, with one design concern about when the deprecation warning fires.
  • All test file cleanups are consistent and complete — no orphaned use_shardy references or dangling jax.config.update calls remain. The version guard in base.py is correctly implemented with PkgVersion and the fallback NotImplementedError for primitives without GSPMD rules is appropriate. Score is 4 rather than 5 because the deprecation warning fires at import time rather than at GSPMD invocation time, which will produce false-positive warnings for users on JAX ≤ 0.9.1 who have already adopted Shardy.
  • transformer_engine/jax/cpp_extensions/base.py — specifically the deprecation warning trigger timing in register_primitive().

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[register_primitive called at import time] --> B{JAX version\n≤ 0.9.1?}
    B -- Yes\n_JAX_GSPMD_SUPPORTED=True --> C{cls defines\ninfer_sharding_from_operands\nin own __dict__?}
    C -- Yes --> D[_warn_gspmd_deprecation_once\nDeprecationWarning issued\nat import time]
    D --> E[gspmd_kwargs = infer_sharding_from_operands\nfrom subclass]
    C -- No --> F[gspmd_kwargs = infer_sharding_from_operands\nfrom BasePrimitive base — raises NotImplementedError]
    E --> G[def_partition with partition +\nshardy_sharding_rule + infer_sharding_from_operands]
    F --> G
    B -- No\n_JAX_GSPMD_SUPPORTED=False --> H[gspmd_kwargs = empty dict]
    H --> I[def_partition with only\npartition + shardy_sharding_rule]
    G --> J[Primitive registered]
    I --> J

    style D fill:#ffcc00,color:#000
    style F fill:#ff9999,color:#000

Last reviewed commit: 69a71c6

@greptile-apps bot (Contributor) left a comment

23 files reviewed, 1 comment

@phu0ngng (Collaborator, Author) commented

/te-ci JAX L1

@greptile-apps bot (Contributor) left a comment

12 files reviewed, no comments

@greptile-apps bot (Contributor) left a comment

17 files reviewed, no comments

@greptile-apps bot (Contributor) left a comment

19 files reviewed, no comments

@phu0ngng (Collaborator, Author) commented Mar 2, 2026

/te-ci JAX L1

@jberchtold-nvidia (Collaborator) left a comment

LGTM once CI finishes, thanks for making this change!

This full removal is valid as users will be in one of the following situations, right?

a) The user is on the latest JAX version, in which case having these GSPMD functions around could cause errors since GSPMD has been removed, so removing this logic as we do in this PR is correct
b) The user is on an older version of JAX, in which case they can use GSPMD or Shardy. There have been updates from JAX itself about this transition so users have been aware they need to move to Shardy by March 2026, so in our case removing it should still be okay
c) The user is on a very old version of JAX (e.g. >1 year old), in which Shardy doesn't work or has bugs. In this case, they might also have other compatibility issues with TE/JAX, in which case they should update to a more recent JAX version

@phu0ngng (Collaborator, Author) commented Mar 2, 2026

Per offline discussion, we decided that GSPMD sharding propagation rules will be kept until June 2026.

This PR only removes the GSPMD tests.
GSPMD rules with the existing primitives will work with older JAX versions (until 0.9.1) with a printed warning.
For new primitives that do not have GSPMD rules, an error will be raised only when the primitive is used with GSPMD.

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng (Collaborator, Author) commented Mar 2, 2026

/te-ci JAX L1

@greptile-apps bot (Contributor) commented Mar 2, 2026

Additional Comments (2)

transformer_engine/jax/cpp_extensions/base.py, line 242
gspmd_kwargs = {} branch is currently unreachable dead code

Since all current public JAX releases are in the 0.4.x to 0.5.x range, PkgVersion(jax.__version__) <= PkgVersion("0.9.1") is always True today. The else: gspmd_kwargs = {} branch, which is the actual Shardy-only code path, will never execute until a JAX version > 0.9.1 ships.

This is purely forward-looking, which is fine and the comment explains the intent. However, it's worth explicitly noting that until JAX reaches 0.9.2+, all registered primitives will still have infer_sharding_from_operands passed to def_partition(), resolved to the base-class implementation that raises NotImplementedError. Users on an old JAX where Shardy is not the default (i.e., jax_use_shardy_partitioner=False) who don't have the Shardy API available will hit the NotImplementedError at partitioning time rather than getting a more descriptive migration path.

Consider either:

  • Using a version boundary that already exists in released JAX (e.g., the version that made Shardy the default), so the branch isn't perpetually dead.
  • Or adding a comment explicitly acknowledging that the else branch is reserved for a future JAX release, to avoid confusion for future maintainers.

transformer_engine/jax/cpp_extensions/base.py, line 240
Deprecation warning never fires for updated primitives

The guard if "infer_sharding_from_operands" in cls.__dict__ only matches classes that define their own override. Since all infer_sharding_from_operands implementations were removed from cpp_extensions in this PR, the deprecation warning will never fire for any of those primitives — even for a user on JAX ≤ 0.9.1 who explicitly opts into GSPMD via jax_use_shardy_partitioner=False.

Instead, they'll silently receive the NotImplementedError from the base-class method at partition time. The NotImplementedError message is informative ("Please use Shardy partitioner instead"), so this is workable, but it's worth double-checking whether the intended audience for the deprecation warning is purely external subclasses (like triton extensions) that still define the method — in which case the current logic is correct and should be clarified with a comment.
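The distinction this comment relies on, between a method a class defines itself and one it merely inherits, can be shown with a small self-contained sketch. The class names and the `defines_own_gspmd_rule` helper are hypothetical; only the `in cls.__dict__` test mirrors the guard quoted in the comment.

```python
class Base:
    @classmethod
    def infer_sharding_from_operands(cls):
        # Base-class default: no GSPMD rule available.
        raise NotImplementedError("Please use Shardy partitioner instead")


class WithOverride(Base):
    @classmethod
    def infer_sharding_from_operands(cls):
        return "custom-gspmd-rule"


class WithoutOverride(Base):
    pass


def defines_own_gspmd_rule(cls) -> bool:
    """True only if cls itself (not a parent) defines the method."""
    return "infer_sharding_from_operands" in cls.__dict__


# hasattr() cannot tell the two subclasses apart: both expose the attribute.
both_have_attribute = hasattr(WithOverride, "infer_sharding_from_operands") and hasattr(
    WithoutOverride, "infer_sharding_from_operands"
)
```

Here `defines_own_gspmd_rule(WithOverride)` is True and `defines_own_gspmd_rule(WithoutOverride)` is False, which is why a guard written this way skips the warning for any primitive whose own override has been removed. Note that the check is also True for `Base` itself, since the default lives in its `__dict__`.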

@phu0ngng changed the title from "[JAX] Deprecate GSPMD: remove infer_sharding_from_operands and GSPMD tests" to "[JAX] Remove GSPMD tests + adding guards and warning msg for GSPMD rules" on Mar 3, 2026
@jberchtold-nvidia (Collaborator) left a comment

Can you update the PR description with the latest info? The PR itself LGTM, thanks!

@phu0ngng merged commit 39d249b into NVIDIA:main on Mar 3, 2026
9 of 12 checks passed
@phu0ngng deleted the rm_gspmd branch on March 3, 2026 16:45
@greptile-apps bot (Contributor) commented Mar 3, 2026

Additional Comments (1)

transformer_engine/jax/cpp_extensions/base.py, line 240
The deprecation warning fires during register_primitive(), which happens at module import time. This means any user running JAX ≤ 0.9.1 will see the DeprecationWarning on every startup — even if they have already switched to Shardy and are not using GSPMD at all.

The PR description states "GSPMD rules with the existing primitives will work with older JAX versions (until 0.9.1) with a printed warning", which implies the warning should fire when GSPMD is actually invoked, not merely because GSPMD support code is present.

A more accurate approach would be to emit the warning inside a wrapper around the GSPMD-callable that is passed to def_partition, so it fires only when JAX actually calls infer_sharding_from_operands during sharding propagation:

if _JAX_GSPMD_SUPPORTED:
    if "infer_sharding_from_operands" in cls.__dict__:
        # The subclass defines its own GSPMD rule: defer the warning until
        # JAX actually calls the rule during sharding propagation.
        original_fn = cls.infer_sharding_from_operands

        def _gspmd_wrapper(*args, **kwargs):
            _warn_gspmd_deprecation_once()
            return original_fn(*args, **kwargs)

        gspmd_kwargs = {"infer_sharding_from_operands": _gspmd_wrapper}
    else:
        # Inherited base-class default, which raises NotImplementedError.
        gspmd_kwargs = {"infer_sharding_from_operands": cls.infer_sharding_from_operands}
else:
    # JAX > 0.9.1: the GSPMD kwarg is omitted entirely.
    gspmd_kwargs = {}

As written, a user who has already migrated to Shardy on an older JAX release will see a misleading "Use it at your own risk" warning even though they are not using GSPMD.
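The difference in timing between warning at registration and warning at invocation can be demonstrated in isolation. The sketch below is a generic, stdlib-only version of the wrapper idea and makes no claim about the merged code: `warn_on_first_call` and `infer_sharding` are hypothetical names, and the once-only state is kept per wrapped function rather than per process.

```python
import functools
import warnings


def warn_on_first_call(fn, message):
    """Wrap fn so a DeprecationWarning is emitted the first time it runs."""
    warned = False

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal warned
        if not warned:
            warned = True
            warnings.warn(message, DeprecationWarning, stacklevel=2)
        return fn(*args, **kwargs)

    return wrapper


def infer_sharding(operand_name):
    # Stand-in for a GSPMD sharding rule.
    return f"sharding-for-{operand_name}"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # "Registration": wrapping alone must not emit anything.
    wrapped = warn_on_first_call(infer_sharding, "GSPMD is deprecated")
    warnings_after_registration = len(caught)
    # "Invocation": the first call warns, later calls stay silent.
    result = wrapped("q")
    wrapped("k")
    warnings_after_calls = len(caught)
```

Under this arrangement a process that registers the rule but never goes through GSPMD propagation records no warning at all, which is the behaviour the comment argues for.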
